import torch

def load_and_check_weights(file_path, map_location='cpu'):
    """Load a saved state dict with ``torch.load`` and print every entry.

    For each key/value pair the key is printed as ``Layer: <name>``, the
    value as ``Weights: <value>``, followed by a ``---`` separator.

    Args:
        file_path: Path to a file readable by ``torch.load``.
        map_location: Device remapping passed through to ``torch.load``.
            Defaults to ``'cpu'`` so checkpoints saved from GPU runs can be
            inspected on machines without CUDA.

    Returns:
        The loaded object (expected to be a mapping, since ``.items()`` is
        called on it), so callers can inspect it further without reloading.
    """
    # Load the model weights.
    state_dict = torch.load(file_path, map_location=map_location)

    for name, param in state_dict.items():
        print(f'Layer: {name}')
        print(f'Weights: {param}')
        print('---')

    return state_dict

if __name__ == "__main__":
    # Checkpoint selection; change these to inspect a different snapshot.
    epoch = 0
    iteration = 50
    # NOTE(review): this relative path is resolved against the current
    # working directory, not this script's location.
    file_path = f'../model_params/witin/quan/ResNet18_param_{epoch}.pth_iter_{iteration}.pth'
    # Load onto the CPU so checkpoints saved from a GPU run can still be
    # inspected on machines without CUDA (same choice as load_and_check_weights).
    checkpoint = torch.load(file_path, map_location='cpu')
    for key, value in checkpoint.items():
        # Non-tensor entries (if any) have no .dtype; report their Python
        # type name instead of crashing with AttributeError.
        dtype = value.dtype if isinstance(value, torch.Tensor) else type(value).__name__
        print(f'Layer: {key}, Data type: {dtype}')
        print(f'Weights: {value}')